# This script estimates changepoints in the cosine similarity scores.
# It does so separately for each version (sample size) and for each country
# (Egypt and Tunisia), and it also computes the F-statistics for every
# country-version combination.
# It produces Figure 3 and Figure H1 in the paper.
library(ggplot2)
library(dplyr)
library(strucchange)
library(ggthemes)
library(gridExtra)
library(ggridges)
library(tidyr)

# Internal country codes to process
countries <- c("masress", "turess")

# Sample sizes that identify each version
sample_sizes <- c(1e4, 5e4, 1e5, 5e5, 1e6, 1.5e6)

# Render each sample size in fixed (non-scientific) notation. The values are
# formatted one at a time so that format() does not pad them to a common width.
formatted_sample_sizes <- vapply(
  sample_sizes,
  function(size) format(size, scientific = FALSE),
  character(1)
)

# Version labels: the formatted sample size followed by the "30k" suffix
versions <- paste0(formatted_sample_sizes, "30k")

# Read the stored cosine similarity table for one country-version pair,
# rename `group`/`val` to `yearwk`/`cos_sim`, convert `yearwk` to Date, sort by
# it, and tag every row with the country and version it came from.
read_cos_sims <- function(country, version) {
  rds_path <- paste0("data/output/cos_sims/", country, "/", "cos_simsdf_all", version, ".rds")
  readRDS(rds_path) %>%
    rename(yearwk = group, cos_sim = val) %>%
    mutate(yearwk = as.Date(yearwk)) %>%
    arrange(yearwk) %>%
    mutate(country = country, version = version)
}

# Stack all tables: versions form the outer ordering, countries the inner one
all_data <- bind_rows(lapply(versions, function(version) {
  bind_rows(lapply(countries, read_cos_sims, version = version))
}))

# Display names for the internal country codes
country_map <- c("masress" = "Egypt", "turess" = "Tunisia")

# Version labels in the order used for factor levels, paired position by
# position with the colour drawn for each version
ordered_versions <- c(
  "1000030k",
  "5000030k",
  "10000030k",
  "50000030k",
  "100000030k",
  "150000030k"
)
colors_for_versions <- c(
  "#E69F00",
  "#56B4E9",
  "#009E73",
  "#F0E442",
  "#0072B2",
  "#D55E00"
)

# Breakpoints and F statistics for every country-version combination.
# Each entry is stored under the key "<country>_<version>" and holds:
#   breakpoints - row indices of the breakpoints within the subset
#   fstats      - F statistics padded with NA at both ends
#   dates       - the `yearwk` values at the breakpoint indices
breakpoints_list <- list()

# Loop through each country
for (current_country in unique(all_data$country)) {
  # Loop through each version
  for (current_version in unique(all_data$version)) {
    # Subset the data for the current country and version
    dat <- all_data %>%
      filter(country == current_country, version == current_version) %>%
      select(yearwk, cos_sim)

    # Regress cosine similarity on the date; the formula refers to the columns
    # of `dat` directly
    dat.model <- dat$cos_sim ~ dat$yearwk

    # F statistics
    fs <- Fstats(dat.model, data = dat)
    ts_object <- fs[["Fstats"]]

    # Pad the F statistics with NA in front and behind, presumably so that the
    # padded vector has one value per row of `dat`.
    # NOTE(review): this treats frequency(ts_object) as the number of
    # observations and the second element of start()/end() as a position in
    # the series - confirm against the strucchange::Fstats documentation.
    nobs <- frequency(ts_object)
    n_prepend <- start(ts_object)[2] - 1
    n_append <- nobs - end(ts_object)[2]
    time_series_padded <- c(rep(NA, n_prepend), ts_object, rep(NA, n_append))

    # Breakpoints
    bp <- breakpoints(dat.model, data = dat)
    sumbp <- summary(bp)
    bps <- sumbp[["breakpoints"]]

    # Keep the non-missing entries from the first row of the breakpoint matrix
    # (presumably the one-breakpoint partition - confirm) and look up the
    # dates at those row indices
    bpints <- bps[1, !is.na(bps[1, ])]
    dates <- dat$yearwk[bpints]

    # Storing the breakpoints and F statistics in the list
    key <- paste(current_country, current_version, sep = "_")
    breakpoints_list[[key]] <- list(
      breakpoints = bpints,
      fstats = time_series_padded,
      dates = dates
    )
  }
}



# Define the countries
countries <- c("masress", "turess")

# Horizontal shift (in days, since dates are converted with as.numeric) that
# is added cumulatively to successive breakpoint lines
offset_value <- 10

# Combined (F-statistic + cosine similarity) plot for each country, keyed by
# the internal country code
combined_plots_list <- list()

# Loop through each country
for (current_country in countries) {
  # Prepare the data for the current country
  country_data <- all_data %>%
    mutate(country_name = country_map[current_country]) %>%
    filter(country == current_country) %>%
    mutate(version = factor(version, levels = ordered_versions)) # Set version as a factor with ordered levels

  # Collect the F statistics of every version, in the order of ordered_versions
  fstats_list <- list()

  for (version in ordered_versions) {
    key <- paste(current_country, version, sep = "_")
    version_fstats <- breakpoints_list[[key]][["fstats"]]
    if (is.null(version_fstats)) {
      # No F statistics stored for this combination: fill with one NA per row
      # of this version so the stacked vector still has the right length
      version_fstats <- rep(
        NA_real_,
        sum(country_data$version == version, na.rm = TRUE)
      )
    }
    fstats_list[[version]] <- version_fstats
  }

  # Stack all fstats values
  all_fstats <- unlist(fstats_list, use.names = FALSE)

  # The stacked vector is assigned by position, so it must have exactly one
  # value per row of country_data (rows are expected version by version, in
  # the order of ordered_versions). Abort loudly if it does not.
  if (length(all_fstats) != nrow(country_data)) {
    stop(
      "The length of stacked fstats does not match the number of rows in country_data",
      call. = FALSE
    )
  }
  country_data$fstats <- all_fstats

  # Initialize a string to store breakpoint dates for the caption
  caption_text <- "Breakpoints: "

  # Plot for cos_sim; the y axis line, ticks, text and title are blanked or
  # coloured white, and the legend is hidden
  p <- ggplot(country_data, aes(x = yearwk, y = cos_sim, color = version)) +
    geom_line(alpha = 0.1, size = 1) +
    scale_color_manual(values = setNames(colors_for_versions, ordered_versions)) +
    labs(x = "Year-week", y = "Cosine Similarity") +
    ylim(-0.2, 0.2) +
    theme_tufte(base_family = "Helvetica") +
    theme(axis.line.y = element_line(colour = "white"),
          axis.ticks.y = element_blank(),
          axis.text.y = element_text(colour = "white"),
          axis.title.y = element_text(colour = "white"),
          legend.position = "none")

  # Plot for fstats, titled with the country's display name and carrying the
  # version legend
  p_fstats <- ggplot(country_data, aes(x = yearwk, y = fstats, color = version)) +
    geom_line(alpha = 1, size = 1) +
    scale_y_continuous(name = "F-statistic") +
    xlab("") +
    labs(title = country_map[current_country], color = "Version") +
    scale_color_manual(values = setNames(colors_for_versions, ordered_versions)) +
    theme_tufte(base_family = "Helvetica") +
    theme(legend.position = c(0.9, 0.9),
          legend.direction = "vertical",
          legend.box = "vertical",
          legend.box.background = element_rect(color = "black", size = .1),
          legend.key.size = unit(.5, "lines"),
          legend.text = element_text(size = 7),
          legend.title = element_text(size = 8))

  # Running horizontal offset and number of versions already added
  offset <- 0
  version_count <- 0

  # Loop through each version
  for (version in ordered_versions) {
    # Check if the breakpoint exists for this country-version combination
    key <- paste(current_country, version, sep = "_")
    if (!is.null(breakpoints_list[[key]]$dates)) {
      # Retrieve the breakpoint date
      bp_date <- breakpoints_list[[key]]$dates

      # Determine the color for the version
      version_color <- colors_for_versions[which(ordered_versions == version)]

      # Add breakpoint line to the plot with an offset
      p <- p + geom_vline(xintercept = as.numeric(bp_date) + offset, color = version_color, alpha = 1, size = 1)

      # Append "<version>: <date> " to the caption, starting a new line before
      # the fourth entry
      line_break <- if (version_count == 3) "\n" else ""
      caption_text <- paste0(caption_text, line_break, version, ": ", format(bp_date, "%Y-%m-%d"), " ")
      version_count <- version_count + 1

      # Update the offset for the next version
      offset <- offset + offset_value
    }
  }

  # Add the caption with breakpoint dates
  p <- p + labs(caption = caption_text) +
    theme(plot.caption = element_text(hjust = 0, margin = margin(t = 10))) # Left align

  # Stack the two panels in one column: p_fstats on top, p below
  combined_plot <- grid.arrange(p_fstats, p, ncol = 1)

  # Store the combined plot in the list
  combined_plots_list[[current_country]] <- combined_plot
}


# Layout matrix (two rows) and relative row heights.
# NOTE(review): neither object is passed to the arrange or save calls in this
# section.
layout_matrix <- rbind(c(1, 1), c(2, 2))
heights <- unit(c(1, 2), c("null", "null"))

# Place the per-country panels next to each other, one column per country
all_countries_plot <- grid.arrange(
  grobs = combined_plots_list,
  ncol = length(countries)
)

# Write Figure 3 to disk
ggsave(
  filename = "plots/fig3.png",
  plot = all_countries_plot,
  width = 12,
  height = 5,
  units = "in",
  dpi = 300
)


# Build one data frame per entry of breakpoints_list (its padded F statistics
# plus the entry's key), then stack them row-wise
fstats_frames <- vector("list", length(breakpoints_list))
for (i in seq_along(breakpoints_list)) {
  entry_key <- names(breakpoints_list)[i]
  fstats_frames[[i]] <- data.frame(
    fstats = breakpoints_list[[entry_key]]$fstats,
    key = entry_key,
    stringsAsFactors = FALSE
  )
}
fstats_df <- do.call(rbind, fstats_frames)

# Derive "country" and "version" from the key, which has the form
# "<country>_<version>"; the key column itself is kept
fstats_df <- fstats_df %>%
  separate(key, into = c("country", "version"), sep = "_", remove = FALSE)

# Replace the internal country codes with the names shown as facet titles
country_map <- c(masress = "Egypt", turess = "Tunisia")
fstats_df$country <- country_map[fstats_df$country]

# Version factor with levels running from the largest to the smallest sample
ordered_versions <- c(
  "150000030k",
  "100000030k",
  "50000030k",
  "10000030k",
  "5000030k",
  "1000030k"
)
fstats_df$version <- factor(fstats_df$version, levels = ordered_versions)

# Ridgeline (joy) plot of the F statistics: one ridge per version, one facet
# per country
ggplot(fstats_df, aes(x = fstats, y = version, fill = version)) +
  geom_density_ridges(alpha = 0.7, scale = 1.5, rel_min_height = 0.01, color = "white") +
  facet_wrap(~ country) +
  scale_fill_brewer(palette = "Set1") +  # or use scale_fill_manual() if you have specific colors
  labs(x = "F-statistic", y = "Version", title = "") +
  theme_tufte(base_family = "Helvetica") +
  theme(
    legend.position = "none",
    strip.text = element_text(size = 12)
  )

# Save Figure H1. No plot object is passed, so ggsave() falls back to its
# default plot - NOTE(review): this is expected to be the ridgeline plot built
# just above.
ggsave(
  filename = "plots/figH1.png",
  width = 12,
  height = 5,
  units = "in",
  dpi = 300
)
